{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3MMAcssHTML"
      },
      "source": [
        "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
        "<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200\" />"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2025 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4qxv4Sn9b8CE"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/keras_inference\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
        "  </td>\n",
        "    <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/keras_inference.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/google/generative-ai-docs/main/site/en/gemma/docs/core/keras_inference.ipynb\"><img src=\"https://ai.google.dev/images/cloud-icon.svg\" width=\"40\" />Open in Vertex AI</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/keras_inference.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PXNm5_p_oxMF"
      },
      "source": [
        "# Run Gemma with Keras\n",
        "\n",
        "Generating content, summarizing, and analysing content are just some of the tasks you can accomplish with Gemma open models. This tutorial shows you how to get started running Gemma using Keras, including generating text content with text and image input. [Keras](https://keras.io/) provides implementations for running Gemma and other models using JAX, PyTorch, and TensorFlow. If you're new to Keras, you might want to read [Getting started with Keras](https://keras.io/getting_started/) before you begin.\n",
        "\n",
        "Gemma 3 and later models support text and image input. Earlier versions of Gemma only support text input, except for some variants, including [PaliGemma](https://ai.google.dev/gemma/docs/setup)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QQ6W7NzRe1VM"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Before starting this tutorial, make sure you have completed the following steps:\n",
        "\n",
        "* Get access to Gemma on [kaggle.com](https://www.kaggle.com).\n",
        "* Select a Colab runtime with sufficient resources to run\n",
        "  the Gemma model size you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).\n",
        "* Generate and configure a Kaggle username and API key.\n",
        "\n",
        "If you need help completing these steps, see the [Gemma setup](https://ai.google.dev/gemma/docs/setup) instructions. After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_gN-IVRC3dQe"
      },
      "source": [
        "### Set environment variables\n",
        "\n",
        "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DrBoa_Urw9Vx"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z9oy3QUmXtSd"
      },
      "source": [
        "### Install Keras packages\n",
        "\n",
        "Install the Keras and KerasHub Python packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UcGLzDeQ8NwN"
      },
      "outputs": [],
      "source": [
        "!pip install -q -U keras-hub\n",
        "!pip install -q -U keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pm5cVOFt5YvZ"
      },
      "source": [
        "### Select a backend\n",
        "\n",
        "Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. [Keras 3](https://keras.io/keras_3) lets you choose the backend: TensorFlow, JAX, or PyTorch. All three will work for this tutorial. For this tutorial, configure the backend for JAX as it typically provides the better performance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7rS7ryTs5wjf"
      },
      "outputs": [],
      "source": [
        "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # Or \"tensorflow\" or \"torch\".\n",
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "599765c72722"
      },
      "source": [
        "### Import packages\n",
        "\n",
        "Import the Keras and KerasHub packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f2fa267d75bc"
      },
      "outputs": [],
      "source": [
        "import keras\n",
        "import keras_hub"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZsxDCbLN555T"
      },
      "source": [
        "## Load model\n",
        "\n",
        "Keras provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). Download and configure a Gemma model using the `Gemma3CausalLM` class to build an end-to-end, causal language modeling implementation for Gemma 3 models. Create the model using the `from_preset()` method, as shown in the following code example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yygIK9DEIldp"
      },
      "outputs": [],
      "source": [
        "gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(\n",
        "    \"gemma3_instruct_4b\",\n",
        "    dtype=\"bfloat16\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XrAWvsU6pI0E"
      },
      "source": [
        "The `Gemma3CausalLM.from_preset()` method instantiates the model from a preset architecture and weights. In the code above, the string `\"gemma#_xxxxxxx\"` specifies a preset version and parameter size for Gemma. You can find the code strings for Gemma models in their **Model Variation** listings on [Kaggle](https://www.kaggle.com/models/keras/gemma3).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E-cSEjULUhST"
      },
      "source": [
        "Once you have the model downloaded, Use the `summary()` function to get more info about the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e5nEbTdApL7W"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Preprocessor: \"gemma3_causal_lm_preprocessor\"</span>\n",
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃<span style=\"font-weight: bold\"> Layer (type)                                                  </span>┃<span style=\"font-weight: bold\">                                   Config </span>┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ gemma3_tokenizer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Gemma3Tokenizer</span>)                            │                      Vocab size: <span style=\"color: #00af00; text-decoration-color: #00af00\">262,144</span> │\n",
              "├───────────────────────────────────────────────────────────────┼──────────────────────────────────────────┤\n",
              "│ gemma3_image_converter (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Gemma3ImageConverter</span>)                 │                   Image size: (<span style=\"color: #00af00; text-decoration-color: #00af00\">896</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">896</span>) │\n",
              "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n",
              "<span style=\"font-weight: bold\">Model: \"gemma3_causal_lm_1\"</span>\n",
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃<span style=\"font-weight: bold\"> Layer (type)                  </span>┃<span style=\"font-weight: bold\"> Output Shape              </span>┃<span style=\"font-weight: bold\">         Param # </span>┃<span style=\"font-weight: bold\"> Connected to               </span>┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ images (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)           │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">896</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">896</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ padding_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)     │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_ids (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)        │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ vision_indices (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ vision_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)      │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ gemma3_backbone               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">2560</span>)         │   <span style=\"color: #00af00; text-decoration-color: #00af00\">4,299,915,632</span> │ images[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],              │\n",
              "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Gemma3Backbone</span>)              │                           │                 │ padding_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],        │\n",
              "│                               │                           │                 │ token_ids[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],           │\n",
              "│                               │                           │                 │ vision_indices[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],      │\n",
              "│                               │                           │                 │ vision_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_embedding               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">262144</span>)       │     <span style=\"color: #00af00; text-decoration-color: #00af00\">671,088,640</span> │ gemma3_backbone[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]      │\n",
              "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ReversibleEmbedding</span>)         │                           │                 │                            │\n",
              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
              "<span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">4,299,915,632</span> (8.79 GB)\n",
              "<span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">4,299,915,632</span> (8.79 GB)\n",
              "<span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1mModel: \"gemma3_causal_lm_1\"\u001b[0m\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "gemma_lm.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "81KHdRYOrWYm"
      },
      "source": [
        "The output of the summary shows the models total number of trainable parameters.\n",
        "For purposes of naming the model, the embedding layer is not counted against the number of parameters."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ij73k0PfUhjE"
      },
      "source": [
        "Note: To run larger Gemma models with Google Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform inferences using [Kaggle](https://www.kaggle.com/code) notebooks or Google Cloud projects.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FOBW7piN5-sl"
      },
      "source": [
        "## Generate text with text\n",
        "\n",
        "Generate text with a text prompt with using `generate()` method of the Gemma model object you configured in the previous steps. The optional `max_length` argument specifies the maximum length of the generated sequence. The following code examples shows a few ways to prompt the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aae5GHrdpj2_"
      },
      "outputs": [],
      "source": [
        "gemma_lm.generate(\"what is keras in 3 bullet points?\", max_length=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mw5XkiHU11Ft"
      },
      "source": [
        "You can also provide batched prompts using a list as input:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xV6vs8_C2BGt"
      },
      "outputs": [],
      "source": [
        "gemma_lm.generate(\n",
        "    [\"what is keras in 3 bullet points?\",\n",
        "     \"The universe is\"],\n",
        "    max_length=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vVlCnY7Gvm7U"
      },
      "source": [
        "If you're running on JAX or TensorFlow backends, you should notice that the second `generate()` call returns an answer more quickly. This performance improvement is because each call to `generate()` for a given batch size and `max_length` is compiled with XLA. The first run is expensive, but subsequent runs are faster."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CLZodRy8bqBQ"
      },
      "source": [
        "### Use a prompt template\n",
        "\n",
        "When building more complex requests or multi-turn chat interactions use a prompt template to structure your request. The following code creates a standard template for Gemma prompts:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "suAz3uOEb4rb"
      },
      "outputs": [],
      "source": [
        "PROMPT_TEMPLATE = \"\"\"<start_of_turn>user\n",
        "{question}\n",
        "<end_of_turn>\n",
        "<start_of_turn>model\n",
        "\"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0VXrxEVScl_P"
      },
      "source": [
        "The following code shows how to use the template to format a simple request:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G6SuO3BdftM7"
      },
      "outputs": [],
      "source": [
        "question = \"\"\"\"what is keras in 3 bullet points?\"\"\"\n",
        "prompt = PROMPT_TEMPLATE.format(question=question)\n",
        "gemma_lm.generate(prompt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MaVWoSpo3XyY"
      },
      "source": [
        "### Optional: Try a different sampler\n",
        "\n",
        "You can control the generation strategy for model object by setting the `sampler` argument on `compile()`. By default, [`\"greedy\"`](https://keras.io/api/keras_nlp/samplers/greedy_sampler/#greedysampler-class) sampling will be used. As an experiment, try setting a [`\"top_k\"`](https://keras.io/api/keras_nlp/samplers/top_k_sampler/) strategy:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mx55VQpN4DAK"
      },
      "outputs": [],
      "source": [
        "gemma_lm.compile(sampler=\"top_k\")\n",
        "gemma_lm.generate(\"The universe is\", max_length=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-okKgK4LfO0f"
      },
      "source": [
        "While the default greedy algorithm always picks the token with the largest probability, the top-K algorithm randomly picks the next token from the tokens of top K probability. You don't have to specify a sampler, and you can ignore the last code snippet if it's not helpful to your use case. If you'd like learn more about the available samplers, see [Samplers](https://keras.io/api/keras_nlp/samplers/)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NvQ1ajNlCjU6"
      },
      "source": [
        "## Generate text with image data\n",
        "\n",
        "With Gemma 3 and later models, you can use images as part of a prompt to generate output. This capability allows you to use Gemma to interpret visual content or use images as data for content generation.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SDp_zfIYD4b3"
      },
      "source": [
        "### Create image loader function\n",
        "\n",
        "The following function loads an image file from a URL and tokenizes it for use in Gemma prompt:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WUgo0S__EX9O"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import PIL\n",
        "\n",
        "def read_image(url):\n",
        "    \"\"\"Reads image from URL as NumPy array.\"\"\"\n",
        "\n",
        "    image_path = keras.utils.get_file(origin=url)\n",
        "    image = PIL.Image.open(image_path)\n",
        "    image = np.array(image)\n",
        "    return image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RBqnJeQReUH_"
      },
      "source": [
        "### Load image for a prompt\n",
        "\n",
        "Load the image and format the data so the model can process it. Use `read_image()` function defined in the previous section, as shown in the example code below:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UnIMXwl2evBe"
      },
      "outputs": [],
      "source": [
        "from matplotlib import pyplot as plt\n",
        "\n",
        "image = read_image(\n",
        "    \"https://ai.google.dev/gemma/docs/images/thali-indian-plate.jpg\"\n",
        ")\n",
        "plt.imshow(image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tTe6GUozMfja"
      },
      "source": [
        "<img src=\"/gemma/docs/images/thali-indian-plate.jpg\" />\n",
        "\n",
        "**Figure 1.** Image of Thali Indian food on a metal plate."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hq-RhOZMcgkY"
      },
      "source": [
        "### Run request with an image\n",
        "\n",
        "When prompting the Gemma model with image content, you use a specific string sequence, `<start_of_image>`, within your prompt to include the image as part of the prompt. Use a prompt template, such as the `PROMPT_TEMPLATE` string defined previously, to format your request as shown in the following prompt code:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HtMiNmD-eESl"
      },
      "outputs": [],
      "source": [
        "question = \"\"\"Which cuisine is this: <start_of_image>? \\\n",
        "Identify the food items present. Which macros is the meal \\\n",
        "high and low on? Keep your answer short.\\\n",
        "\"\"\"\n",
        "\n",
        "gemma_lm.generate(\n",
        "    {\n",
        "        \"images\": image,\n",
        "        \"prompts\": PROMPT_TEMPLATE.format(question=question),\n",
        "    },\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fjUsHgumE9jD"
      },
      "source": [
        "If you are using a smaller GPU, and encountering out of memory (OOM) errors, you can set `max_images_per_prompt` and `sequence_length` to smaller values. The following code shows how to reduce sequence length to 768."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h1nZRvyNFMxt"
      },
      "outputs": [],
      "source": [
        "gemma_lm.preprocessor.max_images_per_prompt = 2\n",
        "gemma_lm.preprocessor.sequence_length = 768"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7E6h-doZfmmb"
      },
      "source": [
        "### Run requests with multiple images\n",
        "\n",
        "When using more than one image in a prompt, use multiple `<start_of_image>` tokens for each provided image, as shown in the following example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XTpFhba-gT5v"
      },
      "outputs": [],
      "source": [
        "dog_a = read_image(\"http://localhost/images/dog-a.jpg\")\n",
        "dog_b = read_image(\"http://localhost/images/dog-b.jpg\")\n",
        "\n",
        "question = \"\"\"I have two images:\n",
        "\n",
        "Dog A: <start_of_image>\n",
        "Dog B: <start_of_image>\n",
        "\n",
        "Which breeds are they? Tell me a bit about them. \\\n",
        "Keep it short.\\\n",
        "\"\"\"\n",
        "\n",
        "gemma_lm.generate(\n",
        "    {\n",
        "        \"images\": [dog_a, dog_b],\n",
        "        \"prompts\": PROMPT_TEMPLATE.format(question=question),\n",
        "    },\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jBrbTYasoo-J"
      },
      "source": [
        "## What's next\n",
        "\n",
        "In this tutorial, you learned how to generate text using Keras and Gemma. Here are a few suggestions for what to learn next:\n",
        "\n",
        "* Learn how to [finetune a Gemma model](https://ai.google.dev/gemma/docs/core/lora_tuning).\n",
        "* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
        "* Learn about [Gemma integration with Vertex AI](https://ai.google.dev/gemma/docs/integrations/vertex)\n",
        "* Learn how to [use Gemma models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "keras_inference.ipynb",
      "toc_visible": true
    },
    "google": {
      "image_path": "/site-assets/images/marketing/gemma.png",
      "keywords": [
        "examples",
        "gemma",
        "python",
        "quickstart",
        "text"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
